-
-
Notifications
You must be signed in to change notification settings - Fork 72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix GaussAdjoint with callbacks #1060
base: master
Are you sure you want to change the base?
Conversation
Ah, this was what #1034 was trying to address lol |
That would've been good to know 😅 |
07d6bdf
to
33a57d2
Compare
Bump? |
f248d52
to
ed0f98f
Compare
hmm ... it's probably close but it currently fails for the callback where the correction should be 0: g(sol) = sum(sol)
function dg!(out, u, p, t, i)
(out .= 1)
end
@testset "callbacks with no effect" begin
condition(u, t, integrator) = t == 5
affect!(integrator) = integrator.u[1] += 0.0
cb = DiscreteCallback(condition, affect!, save_positions = (false, false))
tstops = [5.0]
test_discrete_callback(cb, tstops, g, dg!)
end I think it's the line: if sensealg isa GaussAdjoint
@assert integrator.f.f isa ODEGaussAdjointSensitivityFunction
@show integrator.f.f.integrating_cb.affect!.integrand_values.integrand dgrad
integrator.f.f.integrating_cb.affect!.integrand_values.integrand .= dgrad in callback_tracking.jl, whch for the example above prints: integrator.f.f.integrating_cb.affect!.integrand_values.integrand = [4.318834073956617, -23.50595577533703, 3.507883353046969, -26.70725798446173]
dgrad = [-0.0, -0.0, -0.0, -0.0] |
Is that with the VJPs as Enzyme or ReverseDiff? IIUC it always defaults to ReverseDiff right now? |
@jClugstor did your PR look at GaussAdjoint? |
No, but there is a '@test_broken' for 'GaussAdjoint' with a callback in that PR, so if callbacks get fixed that will need to change. |
No description provided.